# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Private base class for layers that can merge several inputs into one."""

from keras import backend
from keras.engine.base_layer import Layer
from keras.utils import tf_utils
import tensorflow.compat.v2 as tf


class _Merge(Layer):
  """Generic merge layer for elementwise merge functions.

  Used to implement `Add`, `Average`, etc.
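
  A subclass only needs to implement `_merge_function`. For example, a
  minimal sketch mirroring the public `Add` layer:

    class Add(_Merge):

      def _merge_function(self, inputs):
        output = inputs[0]
        for i in range(1, len(inputs)):
          output += inputs[i]
        return output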
  """

  def __init__(self, **kwargs):
    """Initializes a Merge layer.

    Args:
      **kwargs: standard layer keyword arguments.
    """
    super().__init__(**kwargs)
    self.supports_masking = True

  def _merge_function(self, inputs):
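    """Merges a list of tensors into a single output tensor.

    Must be implemented by subclasses.

    Args:
      inputs: List of input tensors, already reshaped by `call` to be
        mutually broadcastable when their ranks differ.

    Returns:
      A single tensor resulting from merging `inputs` elementwise.
    """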
    raise NotImplementedError

  def _compute_elemwise_op_output_shape(self, shape1, shape2):
    """Computes the shape of the resultant of an elementwise operation.

    Args:
        shape1: tuple or None. Shape of the first tensor
        shape2: tuple or None. Shape of the second tensor

    Returns:
        expected output shape when an element-wise operation is
        carried out on 2 tensors with shapes shape1 and shape2.
        tuple or None.

    Raises:
        ValueError: if shape1 and shape2 are not compatible for
            element-wise operations.
    """
    if None in [shape1, shape2]:
      return None
    elif len(shape1) < len(shape2):
      return self._compute_elemwise_op_output_shape(shape2, shape1)  # pylint: disable=arguments-out-of-order
    elif not shape2:
      return shape1
    output_shape = list(shape1[:-len(shape2)])
    for i, j in zip(shape1[-len(shape2):], shape2):
      if i is None or j is None:
        output_shape.append(None)
      elif i == 1:
        output_shape.append(j)
      elif j == 1:
        output_shape.append(i)
      else:
        if i != j:
          raise ValueError(
              'Inputs have incompatible shapes. '
              f'Received shapes {shape1} and {shape2}')
        output_shape.append(i)
    return tuple(output_shape)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Shape validation happens here; `build` also records whether the inputs
    # will need to be reshaped in `call` to become broadcastable.
    if not isinstance(input_shape[0], tuple):
      raise ValueError(
          'A merge layer should be called on a list of inputs. '
          f'Received: input_shape={input_shape} (not a list of shapes)')
    if len(input_shape) < 2:
      raise ValueError('A merge layer should be called '
                       'on a list of at least 2 inputs. '
                       f'Got {len(input_shape)} inputs. '
                       f'Full input_shape received: {input_shape}')
    batch_sizes = {s[0] for s in input_shape if s} - {None}
    if len(batch_sizes) > 1:
      raise ValueError(
          'Cannot merge tensors with different batch sizes. '
          f'Got tensors with shapes {input_shape}')
    if input_shape[0] is None:
      output_shape = None
    else:
      output_shape = input_shape[0][1:]
    for i in range(1, len(input_shape)):
      if input_shape[i] is None:
        shape = None
      else:
        shape = input_shape[i][1:]
      output_shape = self._compute_elemwise_op_output_shape(
          output_shape, shape)
    # If the inputs have different ranks, or if any input's rank is unknown,
    # we will have to reshape them in `call` to make them broadcastable.
    if None not in input_shape and len(set(map(len, input_shape))) == 1:
      self._reshape_required = False
    else:
      self._reshape_required = True

  def call(self, inputs):
    if not isinstance(inputs, (list, tuple)):
      raise ValueError(
          'A merge layer should be called on a list of inputs. '
          f'Received: inputs={inputs} (not a list of tensors)')
    if self._reshape_required:
      reshaped_inputs = []
      input_ndims = list(map(backend.ndim, inputs))
      if None not in input_ndims:
        # If ranks of all inputs are available,
        # we simply expand each of them at axis=1
        # until all of them have the same rank.
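        # For example, a (batch, 5) input merged with a (batch, 4, 5) input
        # is expanded to (batch, 1, 5), which can then broadcast against
        # (batch, 4, 5) inside `_merge_function`.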
        max_ndim = max(input_ndims)
        for x in inputs:
          x_ndim = backend.ndim(x)
          for _ in range(max_ndim - x_ndim):
            x = tf.expand_dims(x, axis=1)
          reshaped_inputs.append(x)
        return self._merge_function(reshaped_inputs)
      else:
        # Transpose all inputs so that batch size is the last dimension.
        # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
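        # With the batch dimension last, trailing-dimension broadcasting in
        # `_merge_function` lines the batch dimensions up across inputs of
        # different (and possibly unknown) rank.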
        transposed = False
        for x in inputs:
          x_ndim = backend.ndim(x)
          if x_ndim is None:
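            # The static rank is unknown, so an explicit transpose permutation
            # cannot be built. Instead: flatten to (batch, prod(rest)), swap
            # the two axes, then reshape to rest + (batch,).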
            x_shape = tf.shape(x)
            batch_size = x_shape[0]
            new_shape = backend.concatenate(
                [x_shape[1:],
                 tf.expand_dims(batch_size, axis=-1)])
            x_transposed = tf.reshape(
                x,
                tf.stack(
                    [batch_size, tf.reduce_prod(x_shape[1:])], axis=0))
            x_transposed = tf.transpose(x_transposed, perm=(1, 0))
            x_transposed = tf.reshape(x_transposed, new_shape)
            reshaped_inputs.append(x_transposed)
            transposed = True
          elif x_ndim > 1:
            dims = list(range(1, x_ndim)) + [0]
            reshaped_inputs.append(tf.transpose(x, perm=dims))
            transposed = True
          else:
            # We don't transpose inputs if they are 1D vectors or scalars.
            reshaped_inputs.append(x)
        y = self._merge_function(reshaped_inputs)
        y_ndim = backend.ndim(y)
        if transposed:
          # If inputs have been transposed, we have to transpose the output too.
          if y_ndim is None:
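            # Undo the transposition dynamically: flatten the output to
            # (prod(rest), batch), swap the two axes, then reshape back to
            # (batch,) + rest.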
            y_shape = tf.shape(y)
            y_ndim = tf.shape(y_shape)[0]
            batch_size = y_shape[y_ndim - 1]
            new_shape = backend.concatenate([
                tf.expand_dims(batch_size, axis=-1), y_shape[:y_ndim - 1]
            ])
            y = tf.reshape(y, (-1, batch_size))
            y = tf.transpose(y, perm=(1, 0))
            y = tf.reshape(y, new_shape)
          elif y_ndim > 1:
            dims = [y_ndim - 1] + list(range(y_ndim - 1))
            y = tf.transpose(y, perm=dims)
        return y
    else:
      return self._merge_function(inputs)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if input_shape[0] is None:
      output_shape = None
    else:
      output_shape = input_shape[0][1:]
    for i in range(1, len(input_shape)):
      if input_shape[i] is None:
        shape = None
      else:
        shape = input_shape[i][1:]
      output_shape = self._compute_elemwise_op_output_shape(
          output_shape, shape)
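    # Use the single known batch size among the inputs if there is exactly
    # one; otherwise leave the output batch size undefined (None).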
    batch_sizes = {s[0] for s in input_shape if s is not None} - {None}
    if len(batch_sizes) == 1:
      output_shape = (list(batch_sizes)[0],) + output_shape
    else:
      output_shape = (None,) + output_shape
    return output_shape

  def compute_mask(self, inputs, mask=None):
    if mask is None:
      return None
    if not isinstance(mask, (tuple, list)):
      raise ValueError(f'`mask` should be a list. Received: mask={mask}')
    if not isinstance(inputs, (tuple, list)):
      raise ValueError(f'`inputs` should be a list. Received: inputs={inputs}')
    if len(mask) != len(inputs):
      raise ValueError(
          'The lists `inputs` and `mask` should have the same length. '
          f'Received: inputs={inputs} of length {len(inputs)}, and '
          f'mask={mask} of length {len(mask)}')
    if all(m is None for m in mask):
      return None
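    # Combine the remaining masks with a logical AND: stack them along a new
    # leading axis and reduce over it, so a position is kept only if every
    # non-None input mask keeps it.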
    masks = [tf.expand_dims(m, axis=0) for m in mask if m is not None]
    return backend.all(
        backend.concatenate(masks, axis=0), axis=0, keepdims=False)

  def get_config(self):  # pylint: disable=useless-super-delegation
    return super().get_config()
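
# Illustrative usage of a layer built on `_Merge` (a sketch; assumes the
# public `tf.keras.layers.Add` subclass, defined elsewhere in Keras):
#
#   x1 = tf.keras.Input(shape=(8,))
#   x2 = tf.keras.Input(shape=(8,))
#   summed = tf.keras.layers.Add()([x1, x2])
#   model = tf.keras.Model(inputs=[x1, x2], outputs=summed)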
